Explaining Machine Learning Models

(Created by Daniel Hernández Mota)

Topics:

  • SHAP
  • LIME
  • Anchors

Focus:

  • A classification model on tabular data

Install the required libraries:

In [1]:
!pip install shap
!pip install anchor-exp
!pip install lime
!pip install lightgbm
!pip install seaborn

Import the libraries to be used:

In [2]:
# Core libraries
import pandas as pd
import numpy as np

# Sklearn
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.metrics import roc_auc_score
from sklearn import datasets

# Visualization
import seaborn as sns
import matplotlib.pyplot as plt

# Model
from lightgbm import LGBMClassifier

# Explainability
import shap
import lime
import lime.lime_tabular
from anchor import anchor_tabular

Dataset description:

The following dataset contains information about breast cancer diagnoses; the full description is printed by the cell below.

The features were computed from a digitized image. The goal is to build a classifier that predicts whether a breast tumor is benign or malignant from a set of features, and then to explain those predictions both locally and globally.

In [3]:
# Load the data
data = datasets.load_breast_cancer()
In [4]:
# Print the dataset description
print(data['DESCR'])
.. _breast_cancer_dataset:

Breast cancer wisconsin (diagnostic) dataset
--------------------------------------------

**Data Set Characteristics:**

    :Number of Instances: 569

    :Number of Attributes: 30 numeric, predictive attributes and the class

    :Attribute Information:
        - radius (mean of distances from center to points on the perimeter)
        - texture (standard deviation of gray-scale values)
        - perimeter
        - area
        - smoothness (local variation in radius lengths)
        - compactness (perimeter^2 / area - 1.0)
        - concavity (severity of concave portions of the contour)
        - concave points (number of concave portions of the contour)
        - symmetry
        - fractal dimension ("coastline approximation" - 1)

        The mean, standard error, and "worst" or largest (mean of the three
        worst/largest values) of these features were computed for each image,
        resulting in 30 features.  For instance, field 0 is Mean Radius, field
        10 is Radius SE, field 20 is Worst Radius.

        - class:
                - WDBC-Malignant
                - WDBC-Benign

    :Summary Statistics:

    ===================================== ====== ======
                                           Min    Max
    ===================================== ====== ======
    radius (mean):                        6.981  28.11
    texture (mean):                       9.71   39.28
    perimeter (mean):                     43.79  188.5
    area (mean):                          143.5  2501.0
    smoothness (mean):                    0.053  0.163
    compactness (mean):                   0.019  0.345
    concavity (mean):                     0.0    0.427
    concave points (mean):                0.0    0.201
    symmetry (mean):                      0.106  0.304
    fractal dimension (mean):             0.05   0.097
    radius (standard error):              0.112  2.873
    texture (standard error):             0.36   4.885
    perimeter (standard error):           0.757  21.98
    area (standard error):                6.802  542.2
    smoothness (standard error):          0.002  0.031
    compactness (standard error):         0.002  0.135
    concavity (standard error):           0.0    0.396
    concave points (standard error):      0.0    0.053
    symmetry (standard error):            0.008  0.079
    fractal dimension (standard error):   0.001  0.03
    radius (worst):                       7.93   36.04
    texture (worst):                      12.02  49.54
    perimeter (worst):                    50.41  251.2
    area (worst):                         185.2  4254.0
    smoothness (worst):                   0.071  0.223
    compactness (worst):                  0.027  1.058
    concavity (worst):                    0.0    1.252
    concave points (worst):               0.0    0.291
    symmetry (worst):                     0.156  0.664
    fractal dimension (worst):            0.055  0.208
    ===================================== ====== ======

    :Missing Attribute Values: None

    :Class Distribution: 212 - Malignant, 357 - Benign

    :Creator:  Dr. William H. Wolberg, W. Nick Street, Olvi L. Mangasarian

    :Donor: Nick Street

    :Date: November, 1995

This is a copy of UCI ML Breast Cancer Wisconsin (Diagnostic) datasets.
https://goo.gl/U2Uwz2

Features are computed from a digitized image of a fine needle
aspirate (FNA) of a breast mass.  They describe
characteristics of the cell nuclei present in the image.

Separating plane described above was obtained using
Multisurface Method-Tree (MSM-T) [K. P. Bennett, "Decision Tree
Construction Via Linear Programming." Proceedings of the 4th
Midwest Artificial Intelligence and Cognitive Science Society,
pp. 97-101, 1992], a classification method which uses linear
programming to construct a decision tree.  Relevant features
were selected using an exhaustive search in the space of 1-4
features and 1-3 separating planes.

The actual linear program used to obtain the separating plane
in the 3-dimensional space is that described in:
[K. P. Bennett and O. L. Mangasarian: "Robust Linear
Programming Discrimination of Two Linearly Inseparable Sets",
Optimization Methods and Software 1, 1992, 23-34].

This database is also available through the UW CS ftp server:

ftp ftp.cs.wisc.edu
cd math-prog/cpo-dataset/machine-learn/WDBC/

.. topic:: References

   - W.N. Street, W.H. Wolberg and O.L. Mangasarian. Nuclear feature extraction 
     for breast tumor diagnosis. IS&T/SPIE 1993 International Symposium on 
     Electronic Imaging: Science and Technology, volume 1905, pages 861-870,
     San Jose, CA, 1993.
   - O.L. Mangasarian, W.N. Street and W.H. Wolberg. Breast cancer diagnosis and 
     prognosis via linear programming. Operations Research, 43(4), pages 570-577, 
     July-August 1995.
   - W.H. Wolberg, W.N. Street, and O.L. Mangasarian. Machine learning techniques
     to diagnose breast cancer from fine-needle aspirates. Cancer Letters 77 (1994) 
     163-171.
In [5]:
data.keys()
Out[5]:
dict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names', 'filename'])
In [6]:
data['target'][:20]
Out[6]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])
In [7]:
X = data['data']
y = data['target']
In [8]:
# Response variable
sns.distplot(y, kde=False)
plt.title("Distribution of the response variable")
plt.xlabel("Is benign?")
plt.ylabel('Count')
plt.show()
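As a numeric complement to the plot, the class balance can be checked directly; a minimal sketch that reloads the same dataset:

```python
import numpy as np
from sklearn.datasets import load_breast_cancer

# Count each class in the target vector (0 = malignant, 1 = benign).
data = load_breast_cancer()
counts = np.bincount(data["target"])
print(dict(zip(data["target_names"], counts)))  # counts per class name
```

This matches the class distribution stated in the description above (212 malignant, 357 benign).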

Model training:

In [9]:
# Train-test split:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, stratify=y)
In [10]:
model = LGBMClassifier()
# Create the model and run a simple hyperparameter grid search.

parameters = {
    'learning_rate': [0.01, 0.1, 0.001],
    'max_depth':[3, 4, 5],
    'n_estimators': [10, 50, 100]}

gs = GridSearchCV(model, parameters)
gs.fit(X_train, y_train, verbose=False)
[LightGBM] [Warning] Accuracy may be bad since you didn't explicitly set num_leaves OR 2^max_depth > num_leaves. (num_leaves=31).
Out[10]:
GridSearchCV(estimator=LGBMClassifier(),
             param_grid={'learning_rate': [0.01, 0.1, 0.001],
                         'max_depth': [3, 4, 5],
                         'n_estimators': [10, 50, 100]})
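GridSearchCV evaluates every combination of the grid above, each with 5-fold cross-validation by default (27 candidates × 5 folds = 135 fits, plus a final refit on the best setting). A sketch of the enumerated search space:

```python
from itertools import product

# Same search space as in the GridSearchCV cell above.
parameters = {
    'learning_rate': [0.01, 0.1, 0.001],
    'max_depth': [3, 4, 5],
    'n_estimators': [10, 50, 100],
}

# One candidate model per element of the Cartesian product: 3 * 3 * 3 = 27.
combos = [dict(zip(parameters, values)) for values in product(*parameters.values())]
print(len(combos))  # 27
```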
In [11]:
# Best parameters and score
gs.best_params_, gs.best_score_
Out[11]:
({'learning_rate': 0.1, 'max_depth': 5, 'n_estimators': 100},
 0.968295739348371)
In [12]:
# Select the best parameters and train a model with them
model.set_params(**gs.best_params_)
model.fit(X_train, y_train)
Out[12]:
LGBMClassifier(max_depth=5)
In [13]:
# Predict on the training set
pred = model.predict(X_train)
pred_prob = model.predict_proba(X_train)
In [14]:
sns.distplot(pred_prob[:,1], bins=60, kde=False)
plt.title("Score distribution (Train-set)")
plt.show()
In [15]:
# Compute the score to assess performance (1.0 on the training set means the model ranks the training data perfectly)
roc_auc_score(y_train, pred_prob[:,1])
Out[15]:
1.0
In [16]:
# Predict on the test set
pred_prob_test = model.predict_proba(X_test)
pred_test = model.predict(X_test)
In [17]:
roc_auc_score(y_test, pred_prob_test[:,1])
Out[17]:
0.989828185938653
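roc_auc_score equals the probability that a randomly chosen positive instance is scored above a randomly chosen negative one (ties counted half). A self-contained sketch of that rank-based definition:

```python
import numpy as np

def roc_auc(y_true, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly."""
    y_true, scores = np.asarray(y_true), np.asarray(scores)
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    greater = (pos[:, None] > neg[None, :]).sum()  # correctly ordered pairs
    ties = (pos[:, None] == neg[None, :]).sum()    # ties count half
    return (greater + 0.5 * ties) / (len(pos) * len(neg))

print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```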
In [18]:
sns.distplot(pred_prob_test[:,1], bins=60, kde=False)
plt.title("Score distribution (test-set)")
plt.show()
In [19]:
# Index of the first malignant (class 0) example in the test set
malign_index = list(y_test).index(0)

Explaining the model...

SHAP:

In [20]:
# Build the SHAP explainer from the trained model
explainer_shap = shap.TreeExplainer(model)
In [21]:
# SHAP values for the positive class (1 = benign)
shap_values = explainer_shap.shap_values(X_test)[1]
shap.summary_plot(shap_values, X_test, feature_names=data['feature_names'])
LightGBM binary classifier with TreeExplainer shap values output has changed to a list of ndarray
In [22]:
# Features sorted by total absolute SHAP value (most important first)
top_inds = np.argsort(-np.sum(np.abs(shap_values), 0))

for i in range(5):
    shap.dependence_plot(top_inds[i], shap_values, X_test, feature_names=data['feature_names'])
In [23]:
shap.initjs()

shap.force_plot(
    base_value=0,
    shap_values=shap_values,
    features=X_test,
    feature_names=data['feature_names'],
    link='logit')
Out[23]:
(Interactive force plot omitted: the SHAP JavaScript visualization is not embedded in this export.)
In [24]:
shap.force_plot(
    base_value=0,
    shap_values=shap_values[malign_index,:],
    features=X_test[malign_index,:],
    feature_names=data['feature_names'],
    link='logit')
Out[24]:
(Interactive force plot omitted: the SHAP JavaScript visualization is not embedded in this export.)
In [25]:
shap.decision_plot(0, shap_values, features=data['feature_names'], link='logit')
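TreeExplainer computes Shapley values with a fast tree-specific algorithm; the underlying definition — a feature's marginal contribution averaged over all orderings — can be sketched exactly for a tiny game. The feature names and contributions below are illustrative, not derived from the dataset:

```python
import math
from itertools import permutations

def shapley_values(players, value):
    """Exact Shapley values: each player's marginal contribution to the
    players seen before it, averaged over all orderings."""
    phi = {p: 0.0 for p in players}
    for order in permutations(players):
        seen = frozenset()
        for p in order:
            phi[p] += value(seen | {p}) - value(seen)
            seen = seen | {p}
    n_orders = math.factorial(len(players))
    return {p: total / n_orders for p, total in phi.items()}

# Toy additive "model": each feature contributes a fixed amount to the score.
contrib = {"radius": 3.0, "texture": 1.0}
value = lambda coalition: sum(contrib[p] for p in coalition)
phi = shapley_values(list(contrib), value)
print(phi)  # {'radius': 3.0, 'texture': 1.0}: additive games recover each contribution
```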

LIME:

In [26]:
explainer_lime = lime.lime_tabular.LimeTabularExplainer(
    X_train,
    feature_names=data['feature_names'],
    class_names=['malignant', 'benign'],
    discretize_continuous=True)
In [27]:
exp_lime = explainer_lime.explain_instance(
    X_test[malign_index,:], model.predict_proba, num_features=2,)
In [28]:
exp_lime.show_in_notebook(show_table=True, show_all=False)
In [29]:
exp_lime = explainer_lime.explain_instance(
    X_test[malign_index,:], model.predict_proba, num_features=10,)
exp_lime.show_in_notebook(show_table=True, show_all=False)
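Under the hood, LimeTabularExplainer perturbs the instance, queries the model, and fits a proximity-weighted linear surrogate whose coefficients serve as the explanation. A self-contained sketch of that idea on a hypothetical two-feature black box (the model, kernel width, and sample count are illustrative choices, not LIME's internals verbatim):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical black box: increasing in feature 0, decreasing in feature 1.
f = lambda X: 1.0 / (1.0 + np.exp(-(2.0 * X[:, 0] - X[:, 1])))

x0 = np.array([0.5, -0.2])                       # instance to explain
X_pert = x0 + 0.1 * rng.normal(size=(500, 2))    # perturbations around x0
w = np.exp(-np.sum((X_pert - x0) ** 2, axis=1) / 0.02)  # proximity weights

# Weighted least squares for the local surrogate g(x) = b0 + b1*x1 + b2*x2.
A = np.column_stack([np.ones(len(X_pert)), X_pert])
sw = np.sqrt(w)
coef, *_ = np.linalg.lstsq(A * sw[:, None], f(X_pert) * sw, rcond=None)
print(coef[1], coef[2])  # positive and negative, matching the black box locally
```

The surrogate's coefficients play the role of the bar lengths in the LIME plots above: they say how each feature moves the prediction in the instance's neighborhood.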

Anchors

In [30]:
explainer_anchor = anchor_tabular.AnchorTabularExplainer(
    ['malignant', 'benign'],
    data['feature_names'],
    X_train)
In [31]:
exp_anchor = explainer_anchor.explain_instance(
     np.array(X_test[malign_index,:]), model.predict, threshold=0.95)

Setting the threshold to 0.95 guarantees, with high probability, that the precision of our explanation is at least 0.95: on instances where the anchor holds, the model's prediction matches the original one at least 95% of the time.
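That precision/coverage trade-off can be estimated by hand: precision is the fraction of instances satisfying the rule whose prediction matches the original one, and coverage is the fraction of instances the rule applies to at all. A sketch with a hypothetical predictor, rule, and perturbation distribution (none of these come from anchor-exp):

```python
import numpy as np

rng = np.random.default_rng(1)

predict = lambda X: (X[:, 0] + X[:, 1] > 0).astype(int)  # toy classifier
rule = lambda X: X[:, 0] > 0.5                           # candidate anchor

X_ref = rng.normal(size=(10_000, 2))   # perturbation distribution
x0 = np.array([[1.0, 0.3]])
target = predict(x0)[0]                # the prediction the anchor should preserve

mask = rule(X_ref)
coverage = mask.mean()                                # how often the rule applies
precision = (predict(X_ref[mask]) == target).mean()   # how often it preserves the prediction
print(round(coverage, 3), round(precision, 3))
```

With these illustrative choices the estimated precision lands below 0.95, so an anchor search with threshold=0.95 would extend the rule with additional predicates before accepting it.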

In [32]:
'Anchor: ' + ' AND '.join(exp_anchor.names())
Out[32]:
'Anchor: perimeter error > 1.55 AND worst area > 1141.00'
In [33]:
(('Precision:', exp_anchor.precision()),
 ('Coverage:', exp_anchor.coverage()))
Out[33]:
(('Precision:', 0.9904214559386973), ('Coverage:', 0.2512))
In [34]:
exp_anchor.show_in_notebook()
In [35]:
exp_anchor = explainer_anchor.explain_instance(
     np.array(X_test[malign_index,:]), model.predict, threshold=0.99)
In [36]:
'Anchor: ' + ' AND '.join(exp_anchor.names())
Out[36]:
'Anchor: area error > 25.00 AND worst radius > 19.19'
In [37]:
(('Precision:', exp_anchor.precision()),
 ('Coverage:', exp_anchor.coverage()))
Out[37]:
(('Precision:', 0.991701244813278), ('Coverage:', 0.2485))